package top.fullj.eventbus;

import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;

/**
 * Registry of event subscribers, keyed by event type.
 *
 * <p>Listener objects are scanned for methods annotated with {@link Subscribe}; each such
 * method is wrapped in a {@link Subscriber} and stored under its event type. The annotated
 * methods found for a listener class are cached in a static cache shared by all registries.
 *
 * <p>Subscribers are kept in a {@link ConcurrentHashMap} holding one {@link CopyOnWriteArraySet}
 * per event type; a missing per-type set is created while holding the map's monitor.
 *
 * @author bruce.wu
 * @since 2022/2/24 15:16
 */
final class Registry {

    /** Annotated subscriber methods per listener class. Cached lists are unmodifiable. */
    private static final Map<Class<?>, List<Method>> METHOD_CACHE = new ConcurrentHashMap<>();

    /** Registered subscribers per event type. A per-type set is never removed once created. */
    private final Map<Class<?>, CopyOnWriteArraySet<Subscriber>> subscriberMap = new ConcurrentHashMap<>();

    /**
     * Returns a snapshot of the subscribers registered under exactly {@code eventType}.
     *
     * @param eventType the event class used as the lookup key
     * @return a mutable copy in registration order, or an immutable empty set when
     *         nothing is registered for {@code eventType}
     */
    Set<Subscriber> getSubscribers(Class<?> eventType) {
        CopyOnWriteArraySet<Subscriber> subscribers = subscriberMap.get(eventType);
        if (subscribers == null) {
            return Collections.emptySet();
        }
        // LinkedHashSet keeps the CopyOnWriteArraySet's insertion (registration) order;
        // a HashSet would reorder subscribers by hash code.
        return new LinkedHashSet<>(subscribers);
    }

    /**
     * Registers one {@link Subscriber} per {@code @Subscribe} method of {@code listener}.
     *
     * @param listener the object whose annotated methods should be subscribed
     * @param handler  called for every sticky subscriber that this call newly added
     * @throws IllegalArgumentException if the listener's class has no {@code @Subscribe} method,
     *         or one of them does not take exactly one parameter
     */
    void register(Object listener, StickySubscribeHandler handler) {
        Class<?> subscriberType = listener.getClass();
        for (Method method : getCachingSubscriberMethods(subscriberType)) {
            Subscriber subscriber = new Subscriber(listener, method);
            boolean absent = subscribe(subscriber);
            // Only fire the sticky callback for subscribers that were not already registered.
            if (subscriber.sticky && absent) {
                handler.call(subscriber);
            }
        }
    }

    /**
     * Removes every subscriber created from {@code listener}'s {@code @Subscribe} methods.
     *
     * @param listener the previously registered object
     * @throws IllegalArgumentException under the same conditions as {@link #register}
     */
    void unregister(Object listener) {
        Class<?> subscriberType = listener.getClass();
        for (Method method : getCachingSubscriberMethods(subscriberType)) {
            unsubscribe(listener, method);
        }
    }

    /**
     * Adds {@code subscriber} to the set for its event type, creating the set if needed.
     *
     * @return {@code true} if the subscriber was not already present
     */
    boolean subscribe(Subscriber subscriber) {
        Class<?> eventType = subscriber.eventType;
        CopyOnWriteArraySet<Subscriber> subscribers = subscriberMap.get(eventType);
        if (subscribers == null) {
            synchronized (subscriberMap) {
                subscribers = Maps.putIfAbsent(subscriberMap, eventType, new CopyOnWriteArraySet<>());
            }
        }
        return subscribers.add(subscriber);
    }

    /**
     * Removes {@code subscriber} from the set for its event type, if that set exists.
     * The (possibly now empty) set itself is left in the map.
     */
    void unsubscribe(Subscriber subscriber) {
        Class<?> eventType = subscriber.eventType;
        CopyOnWriteArraySet<Subscriber> subscribers = subscriberMap.get(eventType);
        if (subscribers != null) {
            subscribers.remove(subscriber);
        }
    }

    /**
     * Removes the subscriber for {@code listener}/{@code method}.
     * This relies on {@code Subscriber} equality between a fresh instance and the registered one.
     */
    void unsubscribe(Object listener, Method method) {
        unsubscribe(new Subscriber(listener, method));
    }

    /**
     * Returns the {@code @Subscribe} methods of {@code type}, computing and caching them on first use.
     *
     * <p>Concurrent first calls for the same type may each compute the list. The results are
     * equivalent, and the last write wins. Failed scans (which throw) are never cached.
     *
     * @param type the listener class to inspect
     * @return an unmodifiable list shared by all callers
     * @throws IllegalArgumentException if {@code type} has no valid {@code @Subscribe} method
     */
    static List<Method> getCachingSubscriberMethods(Class<?> type) {
        // ConcurrentHashMap never stores null values, so null means "not cached yet".
        List<Method> cached = METHOD_CACHE.get(type);
        if (cached != null) {
            return cached;
        }
        List<Method> methods = getAnnotatedSubscriberMethods(type);
        METHOD_CACHE.put(type, methods);
        return methods;
    }

    /**
     * Resolves the class at which the hierarchy scan must stop.
     *
     * @param type      the class whose {@link EventListener} annotation is consulted
     * @param childStop a stop class already found on a subclass, or {@code null}
     * @return {@code childStop} if non-null, otherwise {@code type}'s
     *         {@code @EventListener.stopClass()}, or {@code null} if {@code type} is not annotated
     */
    private static Class<?> getStopClass(Class<?> type, Class<?> childStop) {
        if (childStop != null)
            return childStop;
        EventListener listener = type.getAnnotation(EventListener.class);
        return (listener == null) ?  null : listener.stopClass();
    }

    /**
     * Finds all methods annotated with {@code @Subscribe} in the class tree of {@code type}.
     *
     * <p>The walk goes from {@code type} up through its superclasses. It ends at
     * {@code Object}, or at the stop class from the nearest {@link EventListener} annotation,
     * whichever comes first; the stop class itself is not scanned. Methods are de-duplicated by
     * name and parameter type, and the most-derived declaration wins.
     *
     * @return an unmodifiable list in discovery order (subclass first)
     * @throws IllegalArgumentException if an annotated method does not have exactly one
     *         parameter, or if no annotated method is found
     */
    private static List<Method> getAnnotatedSubscriberMethods(Class<?> type) {
        // LinkedHashMap keeps a deterministic discovery order; HashMap order would depend on
        // Class identity hash codes and vary between runs.
        Map<MethodIdentifier, Method> identifiers = new LinkedHashMap<>();

        Class<?> stop = getStopClass(type, null);
        Class<?> c = type;
        // c becomes null past an interface or primitive; stop there instead of throwing an NPE.
        while (c != null && c != stop && c != Object.class) {
            for (Method method : c.getDeclaredMethods()) {
                // NOTE(review): synthetic/bridge methods are not filtered out. javac may copy
                // @Subscribe onto bridge methods of generic subscribers; confirm whether those
                // should be skipped (e.g. via method.isBridge()).
                if (method.isAnnotationPresent(Subscribe.class)) {
                    Class<?>[] argTypes = method.getParameterTypes();
                    if (argTypes.length != 1) {
                        throw new IllegalArgumentException(
                                String.format("@Subscribe method %s must have exactly 1 parameter",
                                        method.getName()));
                    }
                    MethodIdentifier identifier = new MethodIdentifier(method);
                    // The first (most-derived) declaration of a name/parameter pair is kept.
                    if (!identifiers.containsKey(identifier)) {
                        identifiers.put(identifier, method);
                    }
                }
            }
            c = c.getSuperclass();
            if (c != null) {
                // A stop class declared on a subclass takes precedence over ancestors' ones.
                stop = getStopClass(c, stop);
            }
        }

        if (identifiers.isEmpty()) {
            throw new IllegalArgumentException("@Subscribe method required in class: " + type.getSimpleName());
        }

        return Collections.unmodifiableList(new ArrayList<>(identifiers.values()));
    }

    /** Identity of a subscriber method for de-duplication: method name plus its single parameter type. */
    private static final class MethodIdentifier {

        private final String name;

        private final Class<?> eventType;

        MethodIdentifier(Method method) {
            this.name = method.getName();
            this.eventType = method.getParameterTypes()[0];
        }

        @Override
        public int hashCode() {
            return (31 + name.hashCode()) * 31 + eventType.hashCode();
        }

        @Override
        public boolean equals(Object obj) {
            if (obj instanceof MethodIdentifier) {
                MethodIdentifier other = (MethodIdentifier) obj;
                return name.equals(other.name) && eventType.equals(other.eventType);
            }
            return false;
        }
    }

}
